{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "13306023-d5b5-4c64-a54f-28749bcc7ee0",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:absl:Tensorflow library not found, tensorflow.io.gfile operations will use native shim calls. GCS paths (i.e. 'gs://...') cannot be accessed.\n",
      "WARNING:root:For using TFDSDataLoader please install tensorflow.\n",
      "WARNING:root:For using TFDSDataLoader please install tensorflow.\n",
      "WARNING:root:For using TFDSDataLoader please install tensorflow_datasets.\n"
     ]
    }
   ],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import jax_md\n",
    "import numpy as np\n",
    "\n",
    "from ase.io import read\n",
    "from jax_md import units\n",
    "from typing import Dict\n",
    "\n",
    "from so3lr import to_jax_md\n",
    "from so3lr import So3lrPotential\n",
    "\n",
    "import time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "2954dbc7-4efa-4acc-a72c-ec98de21f3cc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# If you want to perform simulations in float64 you have to call this before any JAX compuation\n",
    "jax.config.update('jax_enable_x64', True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "41f48f5d-2c9c-4c8f-8143-bb320a70253e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Some helper functions used throughout the example notebook.\n",
    "\n",
    "# Default nose hoover chain parameters.\n",
    "def default_nhc_kwargs(\n",
    "    tau: jnp.float32, \n",
    "    overrides: Dict\n",
    ") -> Dict:\n",
    "    \n",
    "    default_kwargs = {\n",
    "        'chain_length': 3, \n",
    "        'chain_steps': 2, \n",
    "        'sy_steps': 3,\n",
    "        'tau': tau\n",
    "    }\n",
    "    \n",
    "    if overrides is None:\n",
    "        return default_kwargs\n",
    "  \n",
    "    return {\n",
    "      k: overrides.get(k, default_kwargs[k]) for k in default_kwargs\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "fca27f5f-6f24-44d8-9c5c-446cabf8fb05",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read some molecular structure. Here we read the DHA structure from the test directory.\n",
    "path_to_xyz = '../tests/test_data/water_64.xyz'\n",
    "atoms = read(path_to_xyz, index=-1)\n",
    "atoms.wrap()\n",
    "\n",
    "# Repeat in each direction.\n",
    "atoms = atoms * [2, 2, 2]\n",
    "\n",
    "species = jnp.array(atoms.get_atomic_numbers())\n",
    "masses = jnp.array(atoms.get_masses())\n",
    "num_atoms = len(species)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0f8e80ad-5101-41c0-b08f-6815b04c4c54",
   "metadata": {},
   "source": [
    "# Periodic Boundary Conditions\n",
    "\n",
    "In contrast to the `nvt_jaxmd.ipynb` example, here the simulations are performed within a box such that periodic boundary conditions (PBCs) need to be applied. The functions calculating the `displacement` vectors between atoms and the `shift` function for updating atomic positions during simulation can be created using `jax_md.space.periodic_general`. It takes the simulation `box` as input, which can be represented with a single scalar for a box with equal side lengths `L`, by a vector `[Lx, Ly, Lz]` for an orthorombic cell, or by an upper triangular matrix for a general triclinic cell. Check also the `jax_md` docs [here](https://jax-md.readthedocs.io/en/main/jax_md.space.html#jax_md.space.periodic_general). The water box loaded above has a cubic box with equal side length such that we can represent it in `jax_md` by a single scalar.\n",
    "\n",
    "Since we are running simulations in fractional coordinates, we need to project the positions on the hypercube with side length 1. For the cubic case this is easy and corresponds to just dividing by the length of the lattice vectors. This is also described [here](https://jax-md.readthedocs.io/en/main/jax_md.space.html#jax_md.space.periodic_general)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "53004b9f-96b3-49bf-9a8a-47a46919b140",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "box = [24.8]\n"
     ]
    }
   ],
   "source": [
    "# it is important that scalar is array of dim = 1. Otherwise weird behavior of jax_md.simulate.npt_box(state) which returns\n",
    "# a 3x3 array when passed a jnp.array of ndim = 0. This leads to TracerConversionError. Maybe open issue on jax_md?\n",
    "\n",
    "box = jnp.array(\n",
    "    [\n",
    "        np.array(atoms.get_cell())[0, 0]\n",
    "    ]\n",
    ")  \n",
    "\n",
    "print('box =', box)\n",
    "\n",
    "fractional_coordinates = True\n",
    "displacement, shift = jax_md.space.periodic_general(box=box, fractional_coordinates=fractional_coordinates)\n",
    "\n",
    "positions_init = jnp.array(atoms.get_positions())\n",
    "if fractional_coordinates:\n",
    "    positions_init = positions_init / box"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "89dafff9-b64c-46da-ac97-244332fbd5c8",
   "metadata": {},
   "source": [
    "# Prepare JAX-MD Energy Function\n",
    "\n",
    "JAX-MD is build around the energy function and the neighborlist. Since we have short- and long-range neighbors, we have two neighborlists compared to the standard case of only a single one. We use the convenience interface `to_jax_md` provided within this repository, which returns the SO3LR energy function as well as the two neighborlists. It takes as input the `So3lrPotential` itself, as well as a displacement function which determines how update the atomic positions (see above)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "95ef5907-e38b-46ee-86ce-55f38b140fe8",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:absl:CheckpointMetadata file does not exist: /Users/thorbenfrank/Documents/git/so3lr/so3lr/params/checkpoints/ckpt_5150000/_CHECKPOINT_METADATA\n",
      "WARNING:absl:`StandardCheckpointHandler` expects a target tree to be provided for restore. Not doing so is generally UNSAFE unless you know the present topology to be the same one as the checkpoint was saved under.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "jax-md partition: cell_size=0.1814516129032258, cl.id_buffer.size=2875, N=1536, cl.did_buffer_overflow=False\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/thorbenfrank/Documents/venvs/so3lr/lib/python3.12/site-packages/jax/_src/ops/scatter.py:92: FutureWarning: scatter inputs have incompatible types: cannot safely cast value from dtype=int64 to dtype=int32 with jax_numpy_dtype_promotion='standard'. In future JAX releases this will result in an error.\n",
      "  warnings.warn(\n",
      "/Users/thorbenfrank/Documents/venvs/so3lr/lib/python3.12/site-packages/jax/_src/ops/scatter.py:92: FutureWarning: scatter inputs have incompatible types: cannot safely cast value from dtype=int64 to dtype=int32 with jax_numpy_dtype_promotion='standard'. In future JAX releases this will result in an error.\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "neighbor_fn, neighbor_fn_lr, energy_fn = to_jax_md(\n",
    "    potential=So3lrPotential(\n",
    "        dtype=jnp.float64  # or float32 for single precision\n",
    "    ),\n",
    "    displacement_or_metric=displacement,\n",
    "    box_size=box,\n",
    "    species=species,\n",
    "    capacity_multiplier=1.25,\n",
    "    buffer_size_multiplier_sr=1.25,\n",
    "    buffer_size_multiplier_lr=1.25,\n",
    "    minimum_cell_size_multiplier_sr=1.0,\n",
    "    disable_cell_list=False,\n",
    "    fractional_coordinates=fractional_coordinates\n",
    ")\n",
    "\n",
    "# Energy function.\n",
    "energy_fn = jax.jit(energy_fn)\n",
    "force_fn = jax.jit(jax_md.quantity.force(energy_fn))\n",
    "\n",
    "# Initialize the short and long-range neighbor lists.\n",
    "nbrs = neighbor_fn.allocate(\n",
    "    positions_init,\n",
    "    box=box\n",
    ")\n",
    "nbrs_lr = neighbor_fn_lr.allocate(\n",
    "    positions_init,\n",
    "    box=box\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "8166e4e3-5df9-4a2e-a00c-065c8b61be4f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Array(-2681.06758185, dtype=float64)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Check that the energy function is working.\n",
    "energy_fn(\n",
    "    positions_init, \n",
    "    neighbor=nbrs.idx, \n",
    "    neighbor_lr=nbrs_lr.idx, \n",
    "    box=box\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7102e75f-3f9e-4ed0-b2a7-b1401e36ad38",
   "metadata": {},
   "source": [
    "# Structure Relaxation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "5beed8eb-6d01-4736-ba14-e92bf3930c02",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Step\tE\tFmax\n",
      "--------------------------\n",
      "jax-md partition: cell_size=0.1814516129032258, cl.id_buffer.size=2875, N=1536, cl.did_buffer_overflow=Traced<ShapedArray(bool[])>with<DynamicJaxprTrace(level=3/0)>\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/thorbenfrank/Documents/venvs/so3lr/lib/python3.12/site-packages/jax/_src/ops/scatter.py:92: FutureWarning: scatter inputs have incompatible types: cannot safely cast value from dtype=int64 to dtype=int32 with jax_numpy_dtype_promotion='standard'. In future JAX releases this will result in an error.\n",
      "  warnings.warn(\n",
      "/Users/thorbenfrank/Documents/venvs/so3lr/lib/python3.12/site-packages/jax/_src/ops/scatter.py:92: FutureWarning: scatter inputs have incompatible types: cannot safely cast value from dtype=int64 to dtype=int32 with jax_numpy_dtype_promotion='standard'. In future JAX releases this will result in an error.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0\t-2781.35\t1.23\n",
      "1\t-2819.48\t1.91\n",
      "2\t-2847.90\t0.64\n",
      "3\t-2860.47\t0.59\n",
      "4\t-2868.68\t0.55\n"
     ]
    }
   ],
   "source": [
    "# For repeated execution of cell, delete the fire_state first, otherwise jit in jupyter environment \n",
    "# can meddle with things.\n",
    "try:\n",
    "    del fire_state\n",
    "except NameError:\n",
    "    pass\n",
    "\n",
    "min_cycles = 5\n",
    "min_steps = 10\n",
    "\n",
    "fire_init, fire_apply = jax_md.minimize.fire_descent(\n",
    "    energy_fn, \n",
    "    shift, \n",
    "    dt_start = 0.05, \n",
    "    dt_max = 0.1, \n",
    "    n_min = 2\n",
    ")\n",
    "\n",
    "fire_apply = jax.jit(fire_apply)\n",
    "fire_state = fire_init(\n",
    "    positions_init, \n",
    "    box=box, \n",
    "    neighbor=nbrs.idx,\n",
    "    neighbor_lr=nbrs_lr.idx\n",
    ")\n",
    "\n",
    "@jax.jit\n",
    "def step_fire_fn(i, fire_state):\n",
    "    \n",
    "    fire_state, nbrs, nbrs_lr = fire_state\n",
    "    \n",
    "    fire_state = fire_apply(\n",
    "        fire_state, \n",
    "        neighbor=nbrs.idx, \n",
    "        neighbor_lr=nbrs_lr.idx, \n",
    "        box=box\n",
    "    )\n",
    "    \n",
    "    nbrs = nbrs.update(\n",
    "        fire_state.position,\n",
    "        neighbor=nbrs.idx,\n",
    "        box=box\n",
    "    )\n",
    "    \n",
    "    nbrs_lr = nbrs_lr.update(\n",
    "        fire_state.position,\n",
    "        neighbor_lr=nbrs_lr.idx,\n",
    "        box=box\n",
    "    )\n",
    "    \n",
    "    return fire_state, nbrs, nbrs_lr\n",
    "\n",
    "print('Step\\tE\\tFmax')\n",
    "print('--------------------------')\n",
    "for i in range(min_cycles):\n",
    "    fire_state, nbrs, nbrs_lr = jax.lax.fori_loop(\n",
    "        0, \n",
    "        min_steps, \n",
    "        step_fire_fn, \n",
    "        (fire_state, nbrs, nbrs_lr)\n",
    "    )\n",
    "    \n",
    "    E = energy_fn(\n",
    "        fire_state.position, \n",
    "        neighbor=nbrs.idx, \n",
    "        neighbor_lr=nbrs_lr.idx,\n",
    "        box=box\n",
    "    )\n",
    "\n",
    "    F = force_fn(\n",
    "        fire_state.position, \n",
    "        neighbor=nbrs.idx, \n",
    "        neighbor_lr=nbrs_lr.idx,\n",
    "        box=box\n",
    "    )\n",
    "    \n",
    "    print('{}\\t{:.2f}\\t{:.2f}'.format(i, E, np.abs(F).max()))\n",
    "    "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a3a56f4d-6088-4c4e-bd95-5556ede90815",
   "metadata": {},
   "source": [
    "Print the delta between original and optimized positions for the first three atoms. The difference is in fractional coordinates. To go back to real space we can use `jax_md.space.transform`. Alternatively, one can use the `displacement` function returned by `jax_md.space.perdiodic_general` which always returns displacements in real space as it has information about if the input is in real or fractional coordinates."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "3c6d81e2-6eca-42e5-b2d8-0b68d8953c85",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Difference in fractional coordinates: \n",
      "[[-0.00725509 -0.00101045  0.00172407]\n",
      " [-0.00976749  0.00028237  0.01174194]\n",
      " [-0.00162628  0.00326012  0.00356884]]\n",
      "\n",
      "\n",
      "Difference in real coordinates using transform function: \n",
      "[[-0.17992621 -0.02505908  0.04275685]\n",
      " [-0.24223367  0.00700275  0.29120013]\n",
      " [-0.04033171  0.08085088  0.08850724]]\n",
      "\n",
      "\n",
      "Difference in real coordinates from displacement function: \n",
      "[[-0.17992621 -0.02505908  0.04275685]\n",
      " [-0.24223367  0.00700275  0.29120013]\n",
      " [-0.04033171  0.08085088  0.08850724]]\n"
     ]
    }
   ],
   "source": [
    "print('Difference in fractional coordinates: ')\n",
    "print((positions_init - fire_state.position)[:3])\n",
    "print('\\n')\n",
    "print('Difference in real coordinates using transform function: ')\n",
    "print((jax_md.space.transform(box, positions_init) - jax_md.space.transform(box, fire_state.position))[:3])\n",
    "print('\\n')\n",
    "print('Difference in real coordinates from displacement function: ')\n",
    "print((jax.vmap(displacement)(positions_init, fire_state.position))[:3])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b7d40e05-eeda-4565-9fab-2b98c0f7523a",
   "metadata": {},
   "source": [
    "# Nose-Hoover Chain NPT simulations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "0008c01b-dd71-480d-8d07-6ec684284d39",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Simulation parameters\n",
    "\n",
    "timestep = 0.0005  # Time step in ps\n",
    "npt_cycles = 25  # Number of Cycles in the NVT.\n",
    "npt_steps = 5  # Number of NVT steps per cylce. The total number of MD steps equals nvt_cylces * nvt_steps\n",
    "\n",
    "T_init = 300  # Initial temperature in K.\n",
    "pressure_init = 1.01325  # Target pressure in bars. \n",
    "\n",
    "chain = 3  # Number of chains in the Nose-Hoover chain.\n",
    "chain_steps = 2  # Number of steps per chain.\n",
    "sy_steps = 3\n",
    "\n",
    "thermo = 100  # Thermostat value in the Nose-Hoover chain. \n",
    "baro = 1000  # Barostat value in the Nose-Hoover chain. \n",
    "\n",
    "# Dictionary with the NHC settings.\n",
    "new_nhc_kwargs = {\n",
    "    'chain_length': chain, \n",
    "    'chain_steps': chain_steps, \n",
    "    'sy_steps': sy_steps\n",
    "}\n",
    "\n",
    "# Convert to metal unit system.\n",
    "unit = units.metal_unit_system()\n",
    "\n",
    "timestep = timestep * unit['time']\n",
    "T_init = T_init * unit['temperature']\n",
    "pressure = pressure_init * unit['pressure']\n",
    "\n",
    "rng_key = jax.random.PRNGKey(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "1511d5ea-60b5-4bac-b355-c07f982baf91",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Step\tKE\tPE\tTot. Energy\tTemp.\tH\ttime/steps\tInvariant drifts (H_i - H_0 , H - H_{i-1}) [meV/atom/ps]\n",
      "-------------------------------------------------------------------------------------------------------------------------------------\n",
      "jax-md partition: cell_size=0.1814516129032258, cl.id_buffer.size=2875, N=1536, cl.did_buffer_overflow=Traced<ShapedArray(bool[])>with<DynamicJaxprTrace(level=3/0)>\n",
      "0\t43.24\t-2851.81\t-2808.570\t217.8\t-2808.578\t13.73\t98.58  ,  98.58\n",
      "5\t46.51\t-2855.30\t-2808.796\t234.2\t-2808.845\t6.70\t14.42  ,  -69.75\n",
      "10\t37.42\t-2845.94\t-2808.518\t188.5\t-2808.633\t6.73\t28.06  ,  55.36\n",
      "15\t42.70\t-2851.23\t-2808.529\t215.1\t-2808.807\t6.69\t9.71  ,  -45.35\n",
      "20\t35.57\t-2843.84\t-2808.273\t179.2\t-2808.738\t6.15\t11.36  ,  17.97\n",
      "25\t31.25\t-2839.25\t-2808.000\t157.4\t-2808.669\t6.11\t12.48  ,  18.07\n",
      "30\t38.19\t-2846.10\t-2807.911\t192.4\t-2808.808\t6.01\t5.52  ,  -36.26\n",
      "35\t38.39\t-2845.85\t-2807.457\t193.4\t-2808.675\t6.00\t9.14  ,  34.52\n",
      "40\t40.97\t-2848.27\t-2807.305\t206.4\t-2808.803\t6.04\t4.43  ,  -33.25\n",
      "45\t34.59\t-2841.54\t-2806.953\t174.2\t-2808.649\t6.24\t8.00  ,  40.11\n",
      "50\t40.32\t-2847.16\t-2806.838\t203.1\t-2808.749\t6.06\t4.90  ,  -26.08\n",
      "55\t42.96\t-2849.55\t-2806.589\t216.4\t-2808.743\t6.03\t4.62  ,  1.57\n",
      "60\t40.16\t-2846.48\t-2806.314\t202.3\t-2808.695\t6.29\t5.23  ,  12.54\n",
      "65\t40.01\t-2846.19\t-2806.173\t201.5\t-2808.748\t6.03\t3.86  ,  -13.93\n",
      "70\t39.99\t-2845.83\t-2805.848\t201.4\t-2808.664\t6.41\t5.08  ,  22.06\n",
      "75\t46.08\t-2851.75\t-2805.671\t232.1\t-2808.777\t6.06\t2.91  ,  -29.60\n",
      "80\t41.24\t-2846.52\t-2805.281\t207.7\t-2808.672\t6.01\t4.36  ,  27.49\n",
      "85\t40.80\t-2845.85\t-2805.046\t205.5\t-2808.711\t6.05\t3.55  ,  -10.15\n",
      "90\t43.05\t-2847.80\t-2804.754\t216.8\t-2808.710\t6.07\t3.37  ,  0.16\n",
      "95\t45.81\t-2850.22\t-2804.407\t230.7\t-2808.710\t6.09\t3.20  ,  -0.02\n",
      "100\t45.50\t-2849.59\t-2804.087\t229.2\t-2808.734\t6.30\t2.75  ,  -6.28\n",
      "105\t41.82\t-2845.51\t-2803.692\t210.7\t-2808.652\t6.08\t3.60  ,  21.41\n",
      "110\t45.90\t-2849.34\t-2803.439\t231.2\t-2808.722\t6.00\t2.65  ,  -18.31\n",
      "115\t47.17\t-2850.21\t-2803.033\t237.6\t-2808.693\t6.24\t2.86  ,  7.77\n",
      "120\t45.98\t-2848.63\t-2802.654\t231.6\t-2808.693\t6.98\t2.74  ,  -0.19\n",
      "Total_time:  874.4880802631378\n"
     ]
    }
   ],
   "source": [
    "# Choose Nose-Hoover thermostat.\n",
    "init_fn, apply_fn = jax_md.simulate.npt_nose_hoover(\n",
    "    energy_fn, \n",
    "    shift, \n",
    "    dt=timestep,\n",
    "    pressure=pressure, \n",
    "    kT=T_init,\n",
    "    barostat_kwargs=default_nhc_kwargs(baro * timestep, new_nhc_kwargs),\n",
    "    thermostat_kwargs=default_nhc_kwargs(thermo * timestep, new_nhc_kwargs)\n",
    ")\n",
    "\n",
    "apply_fn = jax.jit(apply_fn)\n",
    "init_fn = jax.jit(init_fn)\n",
    "\n",
    "# Initialize state using position and neigbhors structure relaxation.\n",
    "state = init_fn(\n",
    "    rng_key, \n",
    "    fire_state.position, \n",
    "    box=box, \n",
    "    neighbor=nbrs.idx, \n",
    "    neighbor_lr=nbrs_lr.idx,\n",
    "    kT=T_init,\n",
    "    mass=masses\n",
    ")\n",
    "\n",
    "@jax.jit\n",
    "def step_npt_fn(i, state):\n",
    "    state, nbrs, nbrs_lr, box = state\n",
    "    \n",
    "    state = apply_fn(\n",
    "        state,\n",
    "        neighbor=nbrs.idx, \n",
    "        neighbor_lr=nbrs_lr.idx, \n",
    "        kT=T_init,\n",
    "        pressure=pressure\n",
    "    )\n",
    "    \n",
    "    box = jax_md.simulate.npt_box(state)\n",
    "    \n",
    "    nbrs = nbrs.update(\n",
    "        state.position, \n",
    "        neighbor=nbrs.idx, \n",
    "        box=box\n",
    "    )\n",
    "    \n",
    "    nbrs_lr = nbrs_lr.update(\n",
    "        state.position,\n",
    "        neighbor_lr=nbrs_lr.idx,\n",
    "        box=box\n",
    "    )\n",
    "    \n",
    "    return state, nbrs, nbrs_lr, box\n",
    "\n",
    "# Track total time and step times averaged over cycle.\n",
    "total_time = time.time()\n",
    "\n",
    "positions_md = []\n",
    "box_md = []\n",
    "\n",
    "print('Step\\tKE\\tPE\\tTot. Energy\\tTemp.\\tH\\ttime/steps\\tInvariant drifts (H_i - H_0 , H - H_{i-1}) [meV/atom/ps]')\n",
    "print('-------------------------------------------------------------------------------------------------------------------------------------')\n",
    "for i in range(npt_cycles):\n",
    "\n",
    "    if i == 0:\n",
    "        initial_H_0 = jax_md.simulate.npt_nose_hoover_invariant(\n",
    "            energy_fn, \n",
    "            state, \n",
    "            pressure=pressure,\n",
    "            kT=T_init,\n",
    "            neighbor=nbrs.idx, \n",
    "            neighbor_lr=nbrs_lr.idx\n",
    "        )\n",
    "    \n",
    "    # Calculate Hamiltonian of the extended NPT dynamics \n",
    "    initial_H = jax_md.simulate.npt_nose_hoover_invariant(\n",
    "        energy_fn, \n",
    "        state, \n",
    "        pressure=pressure,\n",
    "        kT=T_init,\n",
    "        neighbor=nbrs.idx,\n",
    "        neighbor_lr=nbrs_lr.idx\n",
    "    )\n",
    "\n",
    "    old_time = time.time()\n",
    "    \n",
    "    # Do `npt_steps` NPT steps.\n",
    "    new_state, nbrs, nbrs_lr, new_box = jax.block_until_ready(\n",
    "        jax.lax.fori_loop(\n",
    "            0, \n",
    "            npt_steps, \n",
    "            step_npt_fn, \n",
    "            (state, nbrs, nbrs_lr, box)\n",
    "        )\n",
    "    )\n",
    "    \n",
    "    new_time = time.time()\n",
    "    \n",
    "    # Check for overflor of both sr and lr neighbors.\n",
    "    if nbrs.did_buffer_overflow:\n",
    "        print('Neighbor list overflowed, reallocating.')\n",
    "        nbrs = neighbor_fn.allocate(state.position, box = box)\n",
    "        if nbrs_lr.did_buffer_overflow:\n",
    "            print('Long-range neighbor list also overflowed, reallocating.')\n",
    "            nbrs_lr = neighbor_fn_lr.allocate(state.position, box = box)\n",
    "    elif nbrs_lr.did_buffer_overflow:\n",
    "        print('Long-range neighbor list overflowed, reallocating.')\n",
    "        nbrs_lr = neighbor_fn_lr.allocate(state.position, box = box)\n",
    "    else:\n",
    "        state = new_state\n",
    "        box = new_box\n",
    "\n",
    "    # Calculate some quantities for printing\n",
    "    KE = jax_md.quantity.kinetic_energy(\n",
    "        momentum=state.momentum,\n",
    "        mass=state.mass\n",
    "    )\n",
    "    \n",
    "    PE = energy_fn(\n",
    "        state.position,\n",
    "        neighbor=nbrs.idx,\n",
    "        neighbor_lr=nbrs_lr.idx, \n",
    "        box=box\n",
    "    )\n",
    "    \n",
    "    T = jax_md.quantity.temperature(\n",
    "        momentum=state.momentum,\n",
    "        mass = state.mass\n",
    "    ) / unit['temperature']\n",
    "    \n",
    "    # Calculate initial total energy\n",
    "    H = jax_md.simulate.npt_nose_hoover_invariant(\n",
    "        energy_fn,\n",
    "        state,\n",
    "        pressure=pressure,\n",
    "        kT=T_init,\n",
    "        neighbor=nbrs.idx,\n",
    "        neighbor_lr=nbrs_lr.idx\n",
    "    )\n",
    "\n",
    "    energy_drift_h = (H - initial_H) * 1000 / (timestep / unit['time'] * npt_steps * num_atoms)\n",
    "    energy_drift_h_0 = (H - initial_H_0) * 1000 / (timestep / unit['time'] * npt_steps * (i + 1) * num_atoms)\n",
    "    \n",
    "    positions_md.append(np.array(state.position))\n",
    "    box_md.append(np.array(box))\n",
    "    \n",
    "    print(\n",
    "        f'{i*npt_steps}\\t{KE:.2f}\\t{PE:.2f}\\t{KE+PE:.3f}\\t{T:.1f}\\t{H:.3f}\\t{(new_time - old_time) / npt_steps:.2f}\\t{energy_drift_h_0:.2f}  ,  {energy_drift_h:.2f}'\n",
    "    )\n",
    "\n",
    "print('Total_time: ', time.time() - total_time)\n",
    "\n",
    "# Clear all caches\n",
    "jax.clear_caches()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "30a067db-f18e-43a3-bc13-584b3e58066b",
   "metadata": {},
   "source": [
    "# Visualization and IO\n",
    "\n",
    "For now, we concatenated the positions to a simple python list `positions_md`. This allows to directly visualize the trajectory and write the frames to `xyz` after the simulation. However, for production runs it might be neccessary to save the frames (and potentially other information) along the way. For large structures, using `ase.io.write` can be very very slow and reduce the simulation time by sevaral order of magnitudes. To learn how to efficiently write frames and other statistics during simulation, check out the `production_run.ipynb` example notebooks. It uses `.hdf5` files to efficiently perform save operations."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "03148be2-8b2b-4707-8e01-9b1001489a36",
   "metadata": {},
   "outputs": [],
   "source": [
    "# If you want to do visualization in the notebook, install nglview by doing `pip install nglview` in your virtualenv\n",
    "import nglview as nv\n",
    "from ase import Atoms\n",
    "\n",
    "atoms_traj = []\n",
    "for positions, b in zip(positions_md, box_md):\n",
    "    atoms_traj.append(\n",
    "        Atoms(\n",
    "            numbers=np.array(species), \n",
    "            positions=np.array(jax_md.space.transform(box=b, R=positions))  # transform back from fractional coordinates\n",
    "        ),\n",
    "    )\n",
    "\n",
    "nv.show_asetraj(atoms_traj)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "id": "87c6cea4-eec1-4c3f-a554-5afe00c72109",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the frames to xyz.\n",
    "from ase.io import write\n",
    "\n",
    "for frame in atoms_traj:\n",
    "    write( \n",
    "        'npt_water_md_trajectory.xyz',\n",
    "        frame,\n",
    "        append=True\n",
    "    )"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "so3lr",
   "language": "python",
   "name": "so3lr"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
